{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.utils import shuffle\n",
    "import re\n",
    "import time\n",
    "import collections\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words, atleast=1):\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    counter = collections.Counter(words).most_common(n_words)\n",
    "    counter = [i for i in counter if i[1] >= atleast]\n",
    "    count.extend(counter)\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = open('movie_lines.txt', encoding='utf-8', errors='ignore').read().split('\\n')\n",
    "conv_lines = open('movie_conversations.txt', encoding='utf-8', errors='ignore').read().split('\\n')\n",
    "\n",
    "id2line = {}\n",
    "for line in lines:\n",
    "    _line = line.split(' +++$+++ ')\n",
    "    if len(_line) == 5:\n",
    "        id2line[_line[0]] = _line[4]\n",
    "        \n",
    "convs = [ ]\n",
    "for line in conv_lines[:-1]:\n",
    "    _line = line.split(' +++$+++ ')[-1][1:-1].replace(\"'\",\"\").replace(\" \",\"\")\n",
    "    convs.append(_line.split(','))\n",
    "    \n",
    "questions = []\n",
    "answers = []\n",
    "\n",
    "for conv in convs:\n",
    "    for i in range(len(conv)-1):\n",
    "        questions.append(id2line[conv[i]])\n",
    "        answers.append(id2line[conv[i+1]])\n",
    "        \n",
    "def clean_text(text):\n",
    "    text = text.lower()\n",
    "    text = re.sub(r\"i'm\", \"i am\", text)\n",
    "    text = re.sub(r\"he's\", \"he is\", text)\n",
    "    text = re.sub(r\"she's\", \"she is\", text)\n",
    "    text = re.sub(r\"it's\", \"it is\", text)\n",
    "    text = re.sub(r\"that's\", \"that is\", text)\n",
    "    text = re.sub(r\"what's\", \"that is\", text)\n",
    "    text = re.sub(r\"where's\", \"where is\", text)\n",
    "    text = re.sub(r\"how's\", \"how is\", text)\n",
    "    text = re.sub(r\"\\'ll\", \" will\", text)\n",
    "    text = re.sub(r\"\\'ve\", \" have\", text)\n",
    "    text = re.sub(r\"\\'re\", \" are\", text)\n",
    "    text = re.sub(r\"\\'d\", \" would\", text)\n",
    "    text = re.sub(r\"\\'re\", \" are\", text)\n",
    "    text = re.sub(r\"won't\", \"will not\", text)\n",
    "    text = re.sub(r\"can't\", \"cannot\", text)\n",
    "    text = re.sub(r\"n't\", \" not\", text)\n",
    "    text = re.sub(r\"n'\", \"ng\", text)\n",
    "    text = re.sub(r\"'bout\", \"about\", text)\n",
    "    text = re.sub(r\"'til\", \"until\", text)\n",
    "    text = re.sub(r\"[-()\\\"#/@;:<>{}`+=~|.!?,]\", \"\", text)\n",
    "    return ' '.join([i.strip() for i in filter(None, text.split())])\n",
    "\n",
    "clean_questions = []\n",
    "for question in questions:\n",
    "    clean_questions.append(clean_text(question))\n",
    "    \n",
    "clean_answers = []    \n",
    "for answer in answers:\n",
    "    clean_answers.append(clean_text(answer))\n",
    "    \n",
    "min_line_length = 2\n",
    "max_line_length = 5\n",
    "short_questions_temp = []\n",
    "short_answers_temp = []\n",
    "\n",
    "i = 0\n",
    "for question in clean_questions:\n",
    "    if len(question.split()) >= min_line_length and len(question.split()) <= max_line_length:\n",
    "        short_questions_temp.append(question)\n",
    "        short_answers_temp.append(clean_answers[i])\n",
    "    i += 1\n",
    "\n",
    "short_questions = []\n",
    "short_answers = []\n",
    "\n",
    "i = 0\n",
    "for answer in short_answers_temp:\n",
    "    if len(answer.split()) >= min_line_length and len(answer.split()) <= max_line_length:\n",
    "        short_answers.append(answer)\n",
    "        short_questions.append(short_questions_temp[i])\n",
    "    i += 1\n",
    "\n",
    "question_test = short_questions[500:550]\n",
    "answer_test = short_answers[500:550]\n",
    "short_questions = short_questions[:500]\n",
    "short_answers = short_answers[:500]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 657\n",
      "Most common words [('you', 132), ('is', 78), ('i', 68), ('what', 51), ('it', 50), ('that', 49)]\n",
      "Sample data [7, 28, 129, 35, 61, 42, 12, 22, 82, 225] ['what', 'good', 'stuff', 'she', 'okay', 'they', 'do', 'to', 'hey', 'sweet']\n",
      "filtered vocab size: 661\n",
      "% of vocab used: 100.61%\n"
     ]
    }
   ],
   "source": [
    "concat_from = ' '.join(short_questions+question_test).split()\n",
    "vocabulary_size_from = len(list(set(concat_from)))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "print('Sample data', data_from[:10], [rev_dictionary_from[i] for i in data_from[:10]])\n",
    "print('filtered vocab size:',len(dictionary_from))\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary_from)/vocabulary_size_from,4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 660\n",
      "Most common words [('i', 97), ('you', 91), ('is', 62), ('it', 58), ('not', 47), ('what', 39)]\n",
      "Sample data [12, 216, 5, 4, 94, 25, 59, 10, 8, 79] ['the', 'real', 'you', 'i', 'hope', 'so', 'they', 'do', 'not', 'hi']\n",
      "filtered vocab size: 664\n",
      "% of vocab used: 100.61%\n"
     ]
    }
   ],
   "source": [
    "concat_to = ' '.join(short_answers+answer_test).split()\n",
    "vocabulary_size_to = len(list(set(concat_to)))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(concat_to, vocabulary_size_to)\n",
    "print('vocab from size: %d'%(vocabulary_size_to))\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10], [rev_dictionary_to[i] for i in data_to[:10]])\n",
    "print('filtered vocab size:',len(dictionary_to))\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary_to)/vocabulary_size_to,4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary_from['GO']\n",
    "PAD = dictionary_from['PAD']\n",
    "EOS = dictionary_from['EOS']\n",
    "UNK = dictionary_from['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(short_answers)):\n",
    "    short_answers[i] += ' EOS'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Chatbot:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, \n",
    "                 from_dict_size, to_dict_size, learning_rate, \n",
    "                 batch_size, dropout = 0.5, beam_width = 15):\n",
    "        \n",
    "        def lstm_cell(size, reuse=False):\n",
    "            return tf.nn.rnn_cell.BasicRNNCell(size, reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        # encoder\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = lstm_cell(size_layer // 2),\n",
    "                cell_bw = lstm_cell(size_layer // 2),\n",
    "                inputs = encoder_embedded,\n",
    "                sequence_length = self.X_seq_len,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_%d'%(n))\n",
    "            encoder_embedded = tf.concat((out_fw, out_bw), 2)\n",
    "        \n",
    "        bi_state = tf.concat((state_fw, state_bw), -1)\n",
    "        self.encoder_state = tuple([bi_state] * num_layers)\n",
    "        \n",
    "        self.encoder_state = tuple(self.encoder_state[-1] for _ in range(num_layers))\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        # decoder\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([to_dict_size, embedded_size], -1, 1))\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(size_layer) for _ in range(num_layers)])\n",
    "        dense_layer = tf.layers.Dense(to_dict_size)\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(decoder_embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = decoder_embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        self.predicting_ids = predicting_decoder_output.sample_id\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 256\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 0.001\n",
    "batch_size = 16\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-8-e5df6339b591>:7: BasicRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.SimpleRNNCell, and will be replaced by that in Tensorflow 2.0.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Chatbot(size_layer, num_layers, embedded_size, len(dictionary_from), \n",
    "                len(dictionary_to), learning_rate,batch_size)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i.split():\n",
    "            ints.append(dic.get(k,UNK))\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = str_idx(short_questions, dictionary_from)\n",
    "Y = str_idx(short_answers, dictionary_to)\n",
    "X_test = str_idx(question_test, dictionary_from)\n",
    "Y_test = str_idx(answer_test, dictionary_from)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens\n",
    "\n",
    "def check_accuracy(logits, Y):\n",
    "    acc = 0\n",
    "    for i in range(logits.shape[0]):\n",
    "        internal_acc = 0\n",
    "        count = 0\n",
    "        for k in range(len(Y[i])):\n",
    "            try:\n",
    "                if Y[i][k] == logits[i][k]:\n",
    "                    internal_acc += 1\n",
    "                count += 1\n",
    "                if Y[i][k] == EOS:\n",
    "                    break\n",
    "            except:\n",
    "                break\n",
    "        acc += (internal_acc / count)\n",
    "    return acc / logits.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 5.429090, avg accuracy: 0.233640\n",
      "epoch: 2, avg loss: 4.471541, avg accuracy: 0.281546\n",
      "epoch: 3, avg loss: 3.948216, avg accuracy: 0.322755\n",
      "epoch: 4, avg loss: 3.356941, avg accuracy: 0.405858\n",
      "epoch: 5, avg loss: 2.764096, avg accuracy: 0.515825\n",
      "epoch: 6, avg loss: 2.265020, avg accuracy: 0.603198\n",
      "epoch: 7, avg loss: 1.768222, avg accuracy: 0.713267\n",
      "epoch: 8, avg loss: 1.275978, avg accuracy: 0.846005\n",
      "epoch: 9, avg loss: 0.862886, avg accuracy: 0.934014\n",
      "epoch: 10, avg loss: 0.572476, avg accuracy: 0.990927\n",
      "epoch: 11, avg loss: 0.376501, avg accuracy: 1.005328\n",
      "epoch: 12, avg loss: 0.261636, avg accuracy: 1.009624\n",
      "epoch: 13, avg loss: 0.185355, avg accuracy: 1.009712\n",
      "epoch: 14, avg loss: 0.135875, avg accuracy: 1.010056\n",
      "epoch: 15, avg loss: 0.107704, avg accuracy: 1.011005\n",
      "epoch: 16, avg loss: 0.091278, avg accuracy: 1.010056\n",
      "epoch: 17, avg loss: 0.079851, avg accuracy: 1.010560\n",
      "epoch: 18, avg loss: 0.071846, avg accuracy: 1.009586\n",
      "epoch: 19, avg loss: 0.065002, avg accuracy: 1.009558\n",
      "epoch: 20, avg loss: 0.059749, avg accuracy: 1.009586\n"
     ]
    }
   ],
   "source": [
    "for i in range(epoch):\n",
    "    total_loss, total_accuracy = 0, 0\n",
    "    for k in range(0, len(short_questions), batch_size):\n",
    "        index = min(k+batch_size, len(short_questions))\n",
    "        batch_x, seq_x = pad_sentence_batch(X[k: index], PAD)\n",
    "        batch_y, seq_y = pad_sentence_batch(Y[k: index ], PAD)\n",
    "        predicted, accuracy,loss, _ = sess.run([model.predicting_ids, \n",
    "                                                model.accuracy, model.cost, model.optimizer], \n",
    "                                      feed_dict={model.X:batch_x,\n",
    "                                                model.Y:batch_y})\n",
    "        total_loss += loss\n",
    "        total_accuracy += accuracy\n",
    "    total_loss /= (len(short_questions) / batch_size)\n",
    "    total_accuracy /= (len(short_questions) / batch_size)\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(i+1, total_loss, total_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "row 1\n",
      "QUESTION: i am a werewolf\n",
      "REAL ANSWER: a werewolf\n",
      "PREDICTED ANSWER: a werewolf \n",
      "\n",
      "row 2\n",
      "QUESTION: i was dreaming again\n",
      "REAL ANSWER: i would think so\n",
      "PREDICTED ANSWER: i would think so \n",
      "\n",
      "row 3\n",
      "QUESTION: the kitchen\n",
      "REAL ANSWER: very nice\n",
      "PREDICTED ANSWER: very nice \n",
      "\n",
      "row 4\n",
      "QUESTION: the bedroom\n",
      "REAL ANSWER: there is only one bed\n",
      "PREDICTED ANSWER: there is only one bed \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(batch_x)):\n",
    "    print('row %d'%(i+1))\n",
    "    print('QUESTION:',' '.join([rev_dictionary_from[n] for n in batch_x[i] if n not in [0,1,2,3]]))\n",
    "    print('REAL ANSWER:',' '.join([rev_dictionary_to[n] for n in batch_y[i] if n not in[0,1,2,3]]))\n",
    "    print('PREDICTED ANSWER:',' '.join([rev_dictionary_to[n] for n in predicted[i] if n not in[0,1,2,3]]),'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "row 1\n",
      "QUESTION: but david\n",
      "REAL ANSWER: is here that\n",
      "PREDICTED ANSWER: i am so not \n",
      "\n",
      "row 2\n",
      "QUESTION: hopeless it is hopeless\n",
      "REAL ANSWER: tell ballet then back\n",
      "PREDICTED ANSWER: i am afraid \n",
      "\n",
      "row 3\n",
      "QUESTION: miss price\n",
      "REAL ANSWER: yes learning\n",
      "PREDICTED ANSWER: do you have a good \n",
      "\n",
      "row 4\n",
      "QUESTION: mr kessler wake up please\n",
      "REAL ANSWER: is here are\n",
      "PREDICTED ANSWER: i am so \n",
      "\n",
      "row 5\n",
      "QUESTION: there were witnesses\n",
      "REAL ANSWER: why she out\n",
      "PREDICTED ANSWER: yes the money is good \n",
      "\n",
      "row 6\n",
      "QUESTION: what about it\n",
      "REAL ANSWER: not you are\n",
      "PREDICTED ANSWER: i do not know \n",
      "\n",
      "row 7\n",
      "QUESTION: go on ask them\n",
      "REAL ANSWER: i just home\n",
      "PREDICTED ANSWER: and it \n",
      "\n",
      "row 8\n",
      "QUESTION: beware the moon\n",
      "REAL ANSWER: seen hi is he\n",
      "PREDICTED ANSWER: it is harcourt \n",
      "\n",
      "row 9\n",
      "QUESTION: did you hear that\n",
      "REAL ANSWER: is down what\n",
      "PREDICTED ANSWER: the sound \n",
      "\n",
      "row 10\n",
      "QUESTION: i heard that\n",
      "REAL ANSWER: it here not\n",
      "PREDICTED ANSWER: the verona \n",
      "\n",
      "row 11\n",
      "QUESTION: the hound of the baskervilles\n",
      "REAL ANSWER: heard\n",
      "PREDICTED ANSWER: very nice \n",
      "\n",
      "row 12\n",
      "QUESTION: it is moving\n",
      "REAL ANSWER: not you hear\n",
      "PREDICTED ANSWER: it is a secret \n",
      "\n",
      "row 13\n",
      "QUESTION: nice doggie good boy\n",
      "REAL ANSWER: bill stupid\n",
      "PREDICTED ANSWER: it is okay \n",
      "\n",
      "row 14\n",
      "QUESTION: it sounds far away\n",
      "REAL ANSWER: that pecos baby seen hi\n",
      "PREDICTED ANSWER: what is \n",
      "\n",
      "row 15\n",
      "QUESTION: debbie klein cried a lot\n",
      "REAL ANSWER: is will srai not\n",
      "PREDICTED ANSWER: where it \n",
      "\n",
      "row 16\n",
      "QUESTION: what are you doing here\n",
      "REAL ANSWER: is know look i\n",
      "PREDICTED ANSWER: you have got a call \n",
      "\n"
     ]
    }
   ],
   "source": [
    "batch_x, seq_x = pad_sentence_batch(X_test[:batch_size], PAD)\n",
    "batch_y, seq_y = pad_sentence_batch(Y_test[:batch_size], PAD)\n",
    "predicted = sess.run(model.predicting_ids, feed_dict={model.X:batch_x})\n",
    "\n",
    "for i in range(len(batch_x)):\n",
    "    print('row %d'%(i+1))\n",
    "    print('QUESTION:',' '.join([rev_dictionary_from[n] for n in batch_x[i] if n not in [0,1,2,3]]))\n",
    "    print('REAL ANSWER:',' '.join([rev_dictionary_to[n] for n in batch_y[i] if n not in[0,1,2,3]]))\n",
    "    print('PREDICTED ANSWER:',' '.join([rev_dictionary_to[n] for n in predicted[i] if n not in[0,1,2,3]]),'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
